#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import json
from pathlib import Path

import pandas as pd
import numpy as np


def load_config() -> dict:
    """Read the project configuration JSON (UTF-8) and return it as a dict.

    Paths are relative to the current working directory. The real
    ``media_project/config.json`` is preferred; when it does not exist the
    ``media_project/config.example.json`` template is read instead.
    """
    preferred = Path("media_project/config.json")
    fallback = Path("media_project/config.example.json")
    chosen = next((p for p in (preferred,) if p.exists()), fallback)
    raw_text = chosen.read_text(encoding="utf-8")
    return json.loads(raw_text)


def normalize_code6(s: pd.Series) -> pd.Series:
    """Normalise region codes to zero-padded digit strings (nullable string dtype).

    Steps: cast to string and strip whitespace; treat "", "nan" and "None" as
    missing; drop a trailing ".0" (codes that went through a float); keep the
    first run of digits; left-pad with zeros to at least 6 characters.
    Values with no digits become <NA>. Longer digit runs are not truncated.
    """
    placeholders = dict.fromkeys(("", "nan", "None"), pd.NA)
    cleaned = s.astype("string").str.strip().replace(placeholders)
    without_float_suffix = cleaned.str.replace(r"\.0$", "", regex=True)
    digits = without_float_suffix.str.extract(r"(\d+)", expand=False)
    return digits.str.zfill(6)


def normalize_year(s: pd.Series) -> pd.Series:
    """Coerce *s* to a nullable ``Int64`` year series.

    Anything ``pd.to_numeric`` cannot parse becomes <NA>. A non-integral
    number (e.g. 2015.5) makes the ``Int64`` cast raise rather than coerce.
    """
    as_number = pd.to_numeric(s, errors="coerce")
    return as_number.astype("Int64")


def load_co2_city_year(path: str) -> pd.DataFrame:
    """Load city-year CO2 totals from an Excel workbook (first sheet).

    Source columns are renamed: 城市代码 -> city_code6, 城市 -> city_name,
    年份 -> year, CO2排放总量_吨 -> co2_tons. Codes and years are normalised,
    co2_tons is coerced to numeric (unparseable -> NaN), rows without a code
    or year are dropped, only the first row per (city_code6, year) is kept,
    and the result is stably sorted by those keys.

    Raises:
        RuntimeError: if any of the required source columns is missing.
    """
    column_map = {
        "城市代码": "city_code6",
        "城市": "city_name",
        "年份": "year",
        "CO2排放总量_吨": "co2_tons",
    }
    keys = ["city_code6", "year"]

    raw = pd.read_excel(path)
    absent = [name for name in column_map if name not in raw.columns]
    if absent:
        raise RuntimeError(f"CO2 file missing columns: {absent}. Available: {list(raw.columns)}")

    tidy = raw[list(column_map)].rename(columns=column_map)
    tidy = tidy.assign(
        city_code6=normalize_code6(tidy["city_code6"]),
        year=normalize_year(tidy["year"]),
        co2_tons=pd.to_numeric(tidy["co2_tons"], errors="coerce"),
    )
    tidy = tidy.dropna(subset=keys).drop_duplicates(subset=keys)
    return tidy.sort_values(keys, kind="mergesort")


def load_city_controls(path: str) -> pd.DataFrame:
    """Load city-year control variables from the "原始数据" sheet of an Excel workbook.

    Source columns are renamed:
        城市代码 -> city_code6, 城市 -> city_name_citydb, 年份 -> year,
        地区生产总值(万元) -> gdp_wanyuan, 常住人口(万人) -> pop_wan,
        第二产业增加值占GDP比重(%) -> share_secondary_pct,
        第三产业增加值占GDP比重(%) -> share_tertiary_pct.

    Codes and years are normalised, the four numeric controls are coerced
    (unparseable -> NaN), rows without a code or year are dropped, only the
    first row per (city_code6, year) is kept, and the result is stably sorted
    by those keys.

    Raises:
        RuntimeError: if any of the required source columns is missing.
    """
    column_map = {
        "城市代码": "city_code6",
        "城市": "city_name_citydb",
        "年份": "year",
        "地区生产总值(万元)": "gdp_wanyuan",
        "常住人口(万人)": "pop_wan",
        "第二产业增加值占GDP比重(%)": "share_secondary_pct",
        "第三产业增加值占GDP比重(%)": "share_tertiary_pct",
    }
    numeric_cols = ["gdp_wanyuan", "pop_wan", "share_secondary_pct", "share_tertiary_pct"]
    keys = ["city_code6", "year"]

    raw = pd.read_excel(path, sheet_name="原始数据")
    absent = [name for name in column_map if name not in raw.columns]
    if absent:
        raise RuntimeError(f"City DB missing columns: {absent}. Available: {list(raw.columns)}")

    tidy = raw[list(column_map)].rename(columns=column_map)
    tidy = tidy.assign(
        city_code6=normalize_code6(tidy["city_code6"]),
        year=normalize_year(tidy["year"]),
        **{name: pd.to_numeric(tidy[name], errors="coerce") for name in numeric_cols},
    )
    tidy = tidy.dropna(subset=keys).drop_duplicates(subset=keys)
    return tidy.sort_values(keys, kind="mergesort")


def aggregate_aqi_daily_to_city_year(path: str, chunksize: int = 1_000_000) -> pd.DataFrame:
    """Aggregate a daily AQI CSV into city-year summaries, reading it in chunks.

    The CSV (UTF-8, optional BOM) must contain 行政区划代码, 年份, AQI and
    空气质量级别; malformed lines are skipped. Each chunk is reduced to partial
    sums per (city_code6, year), and the partials are re-summed at the end, so
    only one raw chunk is held in memory at a time.

    Args:
        path: Path to the daily AQI CSV.
        chunksize: Number of CSV rows read per chunk.

    Returns:
        DataFrame with columns, in this order:
            city_code6 -- normalised region code,
            year       -- Int64 year,
            aqi_n      -- rows with a numeric AQI,
            days       -- all rows for that city-year (including missing AQI),
            good_days  -- rows whose 空气质量级别 contains 优 or 良,
            aqi_mean   -- float mean AQI over numeric rows (NaN when aqi_n is 0).
    """
    usecols = ["行政区划代码", "年份", "AQI", "空气质量级别"]
    keys = ["city_code6", "year"]

    partial = []
    # Context manager closes the underlying file even if a chunk fails to parse.
    with pd.read_csv(
        path,
        encoding="utf-8-sig",
        usecols=usecols,
        dtype=str,
        chunksize=chunksize,
        low_memory=True,
        on_bad_lines="skip",
    ) as reader:
        for chunk in reader:
            chunk = chunk.dropna(subset=["行政区划代码", "年份"])
            chunk = chunk.assign(
                city_code6=normalize_code6(chunk["行政区划代码"]),
                year=normalize_year(chunk["年份"]),
                aqi=pd.to_numeric(chunk["AQI"], errors="coerce"),
                good_day=chunk["空气质量级别"]
                .astype("string")
                .fillna("")
                .str.contains(r"(?:优|良)", regex=True)
                .astype("Int64"),
            ).dropna(subset=keys)

            g = (
                chunk.groupby(keys, as_index=False)
                .agg(
                    aqi_sum=("aqi", "sum"),
                    aqi_n=("aqi", "count"),
                    days=("aqi", "size"),
                    good_days=("good_day", "sum"),
                )
                .astype({"year": "Int64"})
            )
            partial.append(g)

    if not partial:
        # Typed empty frame with the same column order as the normal path, so
        # downstream merges on (city_code6, year) see matching key dtypes.
        return pd.DataFrame(
            {
                "city_code6": pd.Series(dtype="string"),
                "year": pd.Series(dtype="Int64"),
                "aqi_n": pd.Series(dtype="int64"),
                "days": pd.Series(dtype="int64"),
                "good_days": pd.Series(dtype="Int64"),
                "aqi_mean": pd.Series(dtype="float64"),
            }
        )

    agg = pd.concat(partial, ignore_index=True)
    agg = (
        agg.groupby(keys, as_index=False)
        .agg(aqi_sum=("aqi_sum", "sum"), aqi_n=("aqi_n", "sum"), days=("days", "sum"), good_days=("good_days", "sum"))
        .astype({"year": "Int64"})
    )
    # where() turns zero counts into NaN so the quotient stays float64;
    # replacing 0 with pd.NA upcast the counts (and the mean) to object dtype.
    agg["aqi_mean"] = agg["aqi_sum"] / agg["aqi_n"].where(agg["aqi_n"] > 0)
    agg = agg.drop(columns=["aqi_sum"]).sort_values(keys, kind="mergesort")
    return agg


def load_pilot_panel(path: str, prefix: str) -> pd.DataFrame:
    """Load a pilot-policy city-year table and return prefixed policy columns.

    Reads the first sheet of an Excel file that must contain 城市代码 and 年份.
    Optional source columns are mapped to ``{prefix}_*`` outputs, in priority
    order when several aliases exist:
        {prefix}_treat      <- Treat, 试点城市
        {prefix}_post       <- Post
        {prefix}_did        <- DID
        {prefix}_pilot_year <- 试点启动年份, 低碳城市试点年份
    Aliases are coalesced: the first present column wins, and its missing
    values are filled from the later ones.

    treat/post/did are coerced to numeric with missing -> 0 (Int64), and
    pilot_year is normalised to Int64. When treat and pilot_year are present:
    post is derived as (year >= pilot_year) & (treat == 1) if absent, and did
    as treat * post if absent. Rows without code/year are dropped, and the
    first row per (city_code6, year) is kept.

    Raises:
        RuntimeError: if 城市代码 or 年份 is missing.
    """
    df = pd.read_excel(path)
    if "城市代码" not in df.columns or "年份" not in df.columns:
        raise RuntimeError(f"Pilot file must include 城市代码 and 年份. Available: {list(df.columns)}")

    treat = f"{prefix}_treat"
    post = f"{prefix}_post"
    did = f"{prefix}_did"
    pilot_year = f"{prefix}_pilot_year"

    # Output column -> candidate source columns in priority order. Ordered
    # containers keep the output column order deterministic; selecting via a
    # set of strings varied between runs because of hash randomization.
    sources = {
        treat: ["Treat", "试点城市"],
        post: ["Post"],
        did: ["DID"],
        pilot_year: ["试点启动年份", "低碳城市试点年份"],
    }

    out = pd.DataFrame(
        {
            "city_code6": normalize_code6(df["城市代码"]),
            "year": normalize_year(df["年份"]),
        }
    )
    for target, candidates in sources.items():
        present = [c for c in candidates if c in df.columns]
        if not present:
            continue
        # Coalesce aliases instead of renaming them all to one name, which
        # produced duplicate columns and broke the numeric conversion below.
        column = df[present[0]]
        for fallback in present[1:]:
            column = column.combine_first(df[fallback])
        out[target] = column

    for c in (treat, post, did):
        if c in out.columns:
            out[c] = pd.to_numeric(out[c], errors="coerce").fillna(0).astype("Int64")
    if pilot_year in out.columns:
        out[pilot_year] = normalize_year(out[pilot_year])

    if treat in out.columns and pilot_year in out.columns:
        if post not in out.columns:
            out[post] = ((out["year"] >= out[pilot_year]) & (out[treat] == 1)).astype("Int64")
        if did not in out.columns:
            out[did] = (out[treat] * out[post]).astype("Int64")

    out = out.dropna(subset=["city_code6", "year"]).drop_duplicates(subset=["city_code6", "year"])
    return out


def _positive_or_nan(values: pd.Series) -> pd.Series:
    """Return *values* as float64 with non-positive entries replaced by NaN.

    Guards logarithms and ratio denominators: ``np.log`` gives -inf for 0 and
    NaN for negatives (with a RuntimeWarning), and dividing by 0 gives inf.
    """
    as_float = values.astype("float64")
    return as_float.where(as_float > 0)


def main():
    """Assemble the city-year panel from the configured sources and write it to CSV.

    Config keys read: ``co2_city_xlsx`` and ``aqi_daily_csv`` (required);
    ``out_dir``, ``city_db_xlsx``, ``lowcarbon_pilot_xlsx``,
    ``carbon_trading_pilot_xlsx`` and ``chunksize_aqi`` (optional).

    The CO2 table is the left side of every merge; each right-hand table is
    de-duplicated on (city_code6, year) by its loader. The result is written to
    ``<out_dir>/env_policy_city_year.csv``.
    """
    cfg = load_config()
    out_dir = Path(cfg.get("out_dir", "media_project/out"))
    out_dir.mkdir(parents=True, exist_ok=True)

    co2_path = cfg["co2_city_xlsx"]
    aqi_path = cfg["aqi_daily_csv"]
    city_db_path = cfg.get("city_db_xlsx")
    lowcarbon_path = cfg.get("lowcarbon_pilot_xlsx")
    carbon_trading_path = cfg.get("carbon_trading_pilot_xlsx")
    chunksize = int(cfg.get("chunksize_aqi", 1_000_000))

    co2 = load_co2_city_year(co2_path)
    aqi = aggregate_aqi_daily_to_city_year(aqi_path, chunksize=chunksize)

    panel = co2.merge(aqi, how="left", on=["city_code6", "year"])

    if city_db_path:
        citydb = load_city_controls(city_db_path)
        panel = panel.merge(citydb, how="left", on=["city_code6", "year"])
        # Mask non-positive inputs so zero/negative GDP, CO2 or population
        # yields NaN instead of -inf/inf in the log and ratio columns.
        co2_pos = _positive_or_nan(panel["co2_tons"])
        gdp_pos = _positive_or_nan(panel["gdp_wanyuan"])
        # pop_wan is in units of 10,000 persons (常住人口(万人)); convert to persons.
        persons_pos = _positive_or_nan(panel["pop_wan"]) * 10_000.0
        panel["log_gdp_wanyuan"] = np.log(gdp_pos)
        panel["co2_per_gdp"] = panel["co2_tons"] / gdp_pos
        panel["log_co2_per_gdp"] = np.log(co2_pos) - np.log(gdp_pos)
        panel["co2_per_capita"] = panel["co2_tons"] / persons_pos
        panel["log_co2_per_capita"] = np.log(co2_pos) - np.log(persons_pos)

    if lowcarbon_path:
        low = load_pilot_panel(lowcarbon_path, prefix="lowcarbon")
        panel = panel.merge(low, how="left", on=["city_code6", "year"])

    if carbon_trading_path:
        ct = load_pilot_panel(carbon_trading_path, prefix="carbon_trading")
        panel = panel.merge(ct, how="left", on=["city_code6", "year"])

    out_path = out_dir / "env_policy_city_year.csv"
    panel.to_csv(out_path, index=False, encoding="utf-8-sig")
    print(f"Wrote {out_path} rows={len(panel)} cities={panel['city_code6'].nunique()} years={panel['year'].nunique()}")
    print("Note: script only reads raw files and writes to out_dir; it does not modify the source archives.")


if __name__ == "__main__":
    # Build the panel only when run as a script, not when imported.
    main()
